import re
import os
from typing import Any, List, Dict, Tuple
from collections import defaultdict
from pipelines.prompta.libs.query import DefaultQuery
from pipelines.prompta.utils import query2str, tuple2word, word2tuple
from prompta.core.language import BaseLanguage
from .probabilistic_abstract_oracle import ProbabilisticAbstractOracle


def extract_and_call_function(llm_response):
    """
    Extract Python code from an LLM response, execute it, and return the
    function it defines (the function is returned, NOT called).

    Code is taken from ```python / ```Python fenced blocks; if there are
    none, untagged ``` blocks are used instead. All blocks found are joined
    and executed together in a fresh namespace. The function returned is the
    one named by the first ``def`` appearing in the extracted text.

    SECURITY: the extracted text is run with ``exec``. Only pass responses
    you trust to execute arbitrary code in this process.

    Args:
        llm_response: String containing the LLM response with markdown code blocks

    Returns:
        The callable defined by the extracted code.

    Raises:
        ValueError: if the response contains no fenced code block.
        RuntimeError: if executing the code fails or no function can be found
            in it (the underlying exception is chained as ``__cause__``).
    """
    # Extract the program code
    pattern = r"```(?:python|Python)[\n\r]+(.*?)[\n\r]+```"
    matches = re.findall(pattern, llm_response, re.DOTALL)

    if not matches:
        # Try finding any code blocks if no Python-specific blocks are found
        pattern = r"```[\n\r]+(.*?)[\n\r]+```"
        matches = re.findall(pattern, llm_response, re.DOTALL)

    if not matches:
        raise ValueError("No code blocks found in LLM response")

    # Join all code blocks
    program = "\n\n".join(matches)
    # Strip this annotation: generated code often uses `Tuple` without
    # importing it, and annotations are evaluated when the `def` executes.
    program = program.replace(": Tuple[str, ...]", "")

    # Create a namespace to store the function
    namespace = {}

    # Execute the program in the namespace
    try:
        exec(program, namespace)
        func_names = re.findall(r"def\s+(\w+)\s*\(", program)
        if not func_names:
            raise ValueError("No function definition found in the code")
        func_name = func_names[0]

        # Exact name lookup first: a prefix scan alone could return an
        # earlier, unrelated callable (e.g. an imported `floor` for `def f`).
        candidate = namespace.get(func_name)
        if callable(candidate):
            return candidate

        # Fallback (e.g. the first `def` is a nested method): the original
        # prefix match; getattr guards callables lacking `__name__`.
        for item in namespace.values():
            if callable(item) and getattr(item, "__name__", "").startswith(func_name):
                return item
        raise ValueError("No function found in the code")
    except Exception as e:
        print(program)
        raise RuntimeError(f"Error executing extracted function: {str(e)}") from e


class CodeBasedOracle(ProbabilisticAbstractOracle):
    """Oracle whose answer to a query is computed by a stored, LLM-generated program.

    The program text lives at
    ``<this module's directory>/../test/programs/<language.ctx_name>.txt``.
    The function defined in its code block(s) is extracted with
    ``extract_and_call_function`` and applied to each query string.
    """

    def __init__(self, language: BaseLanguage, *args: Any, **kwargs: Any) -> None:
        super().__init__(language, *args, **kwargs)
        self._load_program(language)

    def reset(self, language: BaseLanguage, exp_dir: str, alphabet=None, load_history=False):
        """Reset the base oracle state, then reload the program for ``language``."""
        super().reset(language, exp_dir, alphabet, load_history)
        self._load_program(language)

    def _load_program(self, language: BaseLanguage) -> None:
        """Read the stored program for ``language`` and extract its function.

        Sets ``self.program_path``, ``self.program`` and ``self.function``.
        Propagates file errors and the errors raised by
        ``extract_and_call_function``.
        """
        self.program_path = os.path.join(
            os.path.dirname(__file__), '..', 'test', 'programs', language.ctx_name + '.txt'
        )
        # `with` closes the handle; explicit encoding avoids platform-dependent
        # decoding of the stored LLM response.
        with open(self.program_path, 'r', encoding='utf-8') as program_file:
            self.program = program_file.read()
        self.function = extract_and_call_function(self.program)

    def __call__(self, input_str: str, *args: Any, **kwargs: Any) -> Any:
        """Return ``self.function(input_str)``.

        Extra positional/keyword arguments are accepted and ignored. Any
        exception raised by the generated code is printed and ``False`` is
        returned instead (best-effort evaluation).
        """
        try:
            return self.function(input_str)
        except Exception as e:
            print(f"Error calling function: {str(e)}")
            return False
